
Clean TF Bert #9788

Merged: 7 commits into huggingface:master on Jan 27, 2021

Conversation

@jplu (Contributor) commented Jan 25, 2021

What does this PR do?

This PR aims to clean the code base of BERT and the other models that depend on it through the # Copied from... mechanism. It also cleans the templates in line with the same changes applied to BERT.

The other models will receive the same type of cleaning, but each model will have its own PR.

@sgugger (Collaborator) left a comment:

I've just commented on the BERT file since you said the others are just adapted for the #Copied from.

It's great to add type annotations everywhere and clean the code like this! I have two suggestions, since the type annotations/keyword arguments can sometimes make the code less readable:

  • Union[List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor] is a mouthful, so it should be saved in a custom type we define in modeling_tf_common (named TFModelInputType, for instance). This will make all call method signatures less daunting, while the type will still be shown in full in the docs.
  • Passing keyword arguments with their names is way better, with one exception: when the arguments are just named x/y, I don't see the point, since it makes the reader pause at "what is x/y?" instead of helping them understand.
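The first suggestion could look something like this (a sketch only; the final name, home module, and exact members of the alias are up to the PR author):

```python
from typing import Dict, List, Union

import numpy as np
import tensorflow as tf

# Hypothetical alias collapsing the long Union into one reusable type,
# as suggested above for modeling_tf_common.
TFModelInputType = Union[
    List[tf.Tensor],
    List[np.ndarray],
    Dict[str, tf.Tensor],
    Dict[str, np.ndarray],
    np.ndarray,
    tf.Tensor,
]
```

Each `call` signature could then read `input_ids: Optional[TFModelInputType] = None`, while the docs would still expand the alias in full.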

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100
# are taken into account as loss
- masked_lm_active_loss = tf.not_equal(tf.reshape(labels["labels"], (-1,)), -100)
+ masked_lm_active_loss = tf.not_equal(x=tf.reshape(tensor=labels["labels"], shape=(-1,)), y=-100)
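The -100 masking pattern under review can be sketched in plain NumPy (a stand-in for the TensorFlow calls, since only the index logic matters here; the example data is made up):

```python
import numpy as np

# Labels set to -100 mark positions that must not contribute to the loss.
labels = np.array([[5, -100, 2], [-100, 7, -100]])

flat = labels.reshape(-1)      # mirrors tf.reshape(labels["labels"], (-1,))
active = flat != -100          # mirrors tf.not_equal(..., -100)
active_labels = flat[active]   # keeps only the positions that count
```

Only the three non-(-100) labels survive the mask, which is exactly what the comment in the snippet above promises.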
Collaborator:

I'm all for using keywords for keyword arguments, but here x and y are not very informative, so it makes the code less readable in this instance. I would leave this part as is.

Member:

Here in particular, the x and y don't help readability at all imo

Comment on lines 114 to 116
next_sentence_active_loss = tf.not_equal(
x=tf.reshape(tensor=labels["next_sentence_label"], shape=(-1,)), y=-100
)
Collaborator:

Same here.

@@ -339,18 +360,18 @@ def call(self, hidden_states, attention_mask=None, head_mask=None, output_attent

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
- attention_scores = attention_scores + attention_mask
+ attention_scores = tf.add(x=attention_scores, y=attention_mask)
Collaborator:

Same here.


# Mask heads if we want to
if head_mask is not None:
- attention_scores = attention_scores * head_mask
+ attention_scores = tf.multiply(x=attention_scores, y=head_mask)
Collaborator:

Same here.

hidden_states = self.LayerNorm(hidden_states)
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+     hidden_states = self.dense(inputs=hidden_states)
+     hidden_states = self.transform_act_fn(x=hidden_states)
Collaborator:

Same here, the x= does not bring anything.

training=False,
input_ids: Optional[
Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
Collaborator:

Same here

training=False,
input_ids: Optional[
Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
Collaborator:

Same here

training=False,
input_ids: Optional[
Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
Collaborator:

Same here

training=False,
input_ids: Optional[
Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
Collaborator:

Same here

training=False,
input_ids: Optional[
Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
Collaborator:

Same here

@LysandreJik (Member) left a comment:

Cool that you added type annotations, this makes it very clear!

One aspect I don't understand is the need to add named arguments to every method call, even when those are supposed to be positional arguments. I think it makes reading harder, as you'll see in my comments below.

Comment on lines 1310 to 1319
- def serving_output(self, output):
-     hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
-     attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
+     hs = tf.convert_to_tensor(value=output.hidden_states) if self.config.output_hidden_states else None
+     attns = tf.convert_to_tensor(value=output.attentions) if self.config.output_attentions else None
Member:

I don't understand why we're putting named arguments everywhere, even when they're not named arguments. For example here, I don't think it makes it any easier to read tf.convert_to_tensor if you specify value, on the contrary.

It makes it seem like this method could (or should) take other arguments, while it is not the case: it should take a single argument, and named parameters can be used to slightly modify the behavior.

Anyway, this is a nit, as it doesn't change anything and you've already switched everything, but I don't understand this choice.
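The two styles being compared boil down to this (a minimal sketch; `tf.convert_to_tensor` does accept its lone tensor argument as the keyword `value`, and the array here is made up):

```python
import numpy as np
import tensorflow as tf

arr = np.zeros((2, 2), dtype=np.float32)

# Positional: the single required argument speaks for itself.
t1 = tf.convert_to_tensor(arr)
# Keyword: `value=` hints at other required arguments that don't exist.
t2 = tf.convert_to_tensor(value=arr)
```

Both calls produce the same tensor; the disagreement is purely about which one reads better.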

Contributor (Author):

What you are saying makes sense, I will rework this part of the PR 👍

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100
# are taken into account as loss
- masked_lm_active_loss = tf.not_equal(tf.reshape(labels["labels"], (-1,)), -100)
+ masked_lm_active_loss = tf.not_equal(x=tf.reshape(tensor=labels["labels"], shape=(-1,)), y=-100)
Member:

Here in particular, the x and y don't help readability at all imo

Comment on lines 904 to 959
- self.bert = TFBertMainLayer(config, name="bert")
+ self.bert = TFBertMainLayer(config=config, name="bert")

Member:

I think this example is exactly why I dislike this change. Before the change, we understood that TFBertMainLayer took one parameter, the configuration, and other named parameters to configure it further (for example TF-specific arguments like name, training, others). It's very clear that to initialize it we're only going to need the configuration.

Now, with this change, I think differently for two reasons:

  • It looks like the TFBertMainLayer has no required arguments. The configuration was passed to it, but it doesn't seem necessary to me. I would even argue it's now on the same level as the name parameter, which is not true.
  • That means that if I want to configure it, then I'll need to go and take a look at the signature to understand what arguments I may be unaware of that would change the TFBertMainLayer's initialization/behavior.

While beforehand the signature was very clear, now it isn't.
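The readability argument above can be reduced to a toy example (`MainLayer` is a hypothetical stand-in, not the real TFBertMainLayer):

```python
class MainLayer:
    # Hypothetical stand-in: `config` is required, `name` is an optional
    # Keras-style keyword used to label the layer.
    def __init__(self, config, *, name=None):
        self.config = config
        self.name = name

# Before the change: config is visibly the one required argument.
layer_a = MainLayer({"hidden_size": 768}, name="bert")
# After the change: config reads like just another optional keyword.
layer_b = MainLayer(config={"hidden_size": 768}, name="bert")
```

Both calls behave identically; the objection is that the second form hides which argument is actually required.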

@jplu (Contributor, Author) commented Jan 26, 2021

I fully reworked the keyword additions to keep only those that seemed the most meaningful.

@sgugger (Collaborator) left a comment:

Reviewed all changes and this looks great to me! Thanks for making all the adjustments!

@LysandreJik (Member) left a comment:

Cool, clean changes! Thanks for iterating!

@patrickvonplaten (Contributor) left a comment:

Changes look good to me & I agree with the previous comments from @LysandreJik and @sgugger. Thanks a lot for your work here!

The failing CI test is related to a bug I hoped to have fixed yesterday. Before merging, could you please merge master into this PR or rebase it to check that CircleCI is no longer red?

@jplu jplu merged commit 4adbdce into huggingface:master Jan 27, 2021
@jplu jplu deleted the clean-tf-bert branch January 27, 2021 12:25
Qbiwan pushed a commit to Qbiwan/transformers that referenced this pull request Jan 31, 2021
* Start cleaning BERT

* Clean BERT and all those depends of it

* Fix attribute name

* Apply style

* Apply Sylvain's comments

* Apply Lysandre's comments

* remove unused import
@LysandreJik LysandreJik mentioned this pull request Feb 22, 2021
4 participants